In [1]:
import torch
import matplotlib.pyplot as plt
import matplotlib.animation as animation

from shapely import geometry, affinity
from IPython.display import HTML
In [2]:
class DistanceOverUnionWithShapely(torch.autograd.Function):
    """Distance-IoU (DIoU) loss between a predicted rotated rectangle and a shapely polygon.

    The prediction is a 1-D tensor ``(x, y, w, h, t)``: centre ``(x, y)``,
    width ``w``, height ``h`` and rotation ``t`` in radians. The loss is

        1 - IoU + d**2 / c**2

    where ``d`` is the distance between the two centroids and ``c`` is the
    diagonal of the axis-aligned box enclosing both shapes.

    Shapely geometry is not differentiable, so ``backward`` estimates the
    gradient with central finite differences (step ``FD_EPS``).
    """

    # Finite-difference step used by ``backward``.
    FD_EPS = 1e-4

    @staticmethod
    def _build_rect(x, y, w, h, t):
        """Return the predicted rectangle as a shapely polygon.

        A ``w`` x ``h`` box centred on the origin is moved to ``(x, y)`` and then
        rotated by ``t`` radians. shapely's default rotation origin is the
        bounding-box centre, which is ``(x, y)`` after the translation.
        """
        rect = geometry.box(-w / 2, -h / 2, w / 2, h / 2)
        rect = affinity.translate(rect, xoff=x, yoff=y)
        return affinity.rotate(rect, t, use_radians=True)

    @staticmethod
    def _loss(params, target_rect):
        """Compute the DIoU loss as a Python float.

        Args:
            params: sequence of five floats ``(x, y, w, h, t)``.
            target_rect: shapely polygon of the ground-truth rectangle.

        Returns:
            ``1 - IoU + d**2 / c**2``. Degenerate geometry (zero-area union or
            zero-size enclosing box) makes the affected term 0 instead of
            raising ``ZeroDivisionError``.
        """
        rect = DistanceOverUnionWithShapely._build_rect(*params)

        union_area = rect.union(target_rect).area
        iou = rect.intersection(target_rect).area / union_area if union_area > 0 else 0.0

        # Squared diagonal of the axis-aligned box enclosing both shapes.
        min_x, min_y, max_x, max_y = geometry.GeometryCollection([target_rect, rect]).bounds
        diag_sq = (max_x - min_x) ** 2 + (max_y - min_y) ** 2
        dist_sq = rect.centroid.distance(target_rect.centroid) ** 2
        center_term = dist_sq / diag_sq if diag_sq > 0 else 0.0

        return 1.0 - iou + center_term

    @staticmethod
    def forward(ctx, parameters, target_rect):
        """Compute the DIoU loss of ``parameters`` against ``target_rect``.

        Returns a 0-d tensor with the same dtype and device as ``parameters``.
        """
        # Convert through CPU float64 so GPU tensors work and shapely gets full precision.
        values = parameters.detach().cpu().double().tolist()
        loss = DistanceOverUnionWithShapely._loss(values, target_rect)

        ctx.save_for_backward(parameters)
        ctx.target_rect = target_rect  # non-tensor input, kept on ctx for backward

        return parameters.new_tensor(loss)

    @staticmethod
    def backward(ctx, grad_output):
        """Estimate d(loss)/d(parameters) by central finite differences.

        Perturbations are applied in float64 so the step is not rounded away
        by float32 parameters. The gradient for ``target_rect`` is ``None``
        because it is not a tensor.
        """
        parameters, = ctx.saved_tensors
        target_rect = ctx.target_rect
        loss_fn = DistanceOverUnionWithShapely._loss
        eps = DistanceOverUnionWithShapely.FD_EPS

        base = parameters.detach().cpu().double().tolist()
        grads = []
        for i in range(len(base)):
            plus = list(base)
            minus = list(base)
            plus[i] += eps
            minus[i] -= eps
            # Divide by the step actually taken rather than the nominal 2 * eps.
            step = plus[i] - minus[i]
            grads.append((loss_fn(plus, target_rect) - loss_fn(minus, target_rect)) / step)

        return parameters.new_tensor(grads) * grad_output, None
In [3]:
# Optimisation settings.
LEARNING_RATE = 0.01
MAX_EPOCHS = 1000
LOSS_TOLERANCE = 1e-3  # stop as soon as the recorded DIoU loss falls below this

# Predicted rectangle (x, y, w, h, t): centre, width, height, rotation in radians.
parameters = torch.tensor([-3.0, -3.0, 2.0, 1.0, 0.0], requires_grad=True)
optimizer = torch.optim.Adam([parameters], lr=LEARNING_RATE)

# Ground truth: a 3 x 4 box shifted by (2, 2), then rotated by 45 degrees.
target_rect = geometry.box(-2, -2, 1, 2)
target_rect = affinity.translate(target_rect, xoff=2, yoff=2)
target_rect = affinity.rotate(target_rect, 45, use_radians=False)

# Per-epoch history for the animation: the loss and predicted rectangle
# observed *before* each optimiser step.
losses = []
rects = []

for epoch in range(MAX_EPOCHS):
    x, y, w, h, t = parameters.detach().numpy()
    rect = geometry.box(-w / 2, -h / 2, w / 2, h / 2)
    rect = affinity.translate(rect, xoff=x, yoff=y)
    rect = affinity.rotate(rect, t, use_radians=True)

    optimizer.zero_grad()
    loss = DistanceOverUnionWithShapely.apply(parameters, target_rect)
    loss.backward()
    optimizer.step()

    loss_value = loss.item()
    losses.append(loss_value)
    rects.append(rect)

    if loss_value < LOSS_TOLERANCE:
        break
In [4]:
fig, ax = plt.subplots()

def animate(frame):
    """Redraw the axes for optimisation epoch ``frame``.

    Draws the ground-truth rectangle, the predicted rectangle recorded at that
    epoch, and the segment joining their centroids. The axes are fully cleared
    on every frame, so the animation must run without blitting.
    """
    rect = rects[frame]
    centre_link = geometry.LineString([rect.centroid, target_rect.centroid])

    ax.clear()
    ax.set_xlim(-10, 10)
    ax.set_ylim(-10, 10)
    ax.set_aspect("equal")
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.grid(True, alpha=0.3)
    ax.plot(*target_rect.exterior.xy, color="green", label="ground-truth", linewidth=1)
    ax.plot(*rect.exterior.xy, color="black", label="predicted", linewidth=1)
    ax.plot(*centre_link.xy, color="red", label="distance", linewidth=1, linestyle="dotted")
    ax.legend(loc="upper left")
    ax.set_title(f"Epoch: {frame + 1}, DIoU: {losses[frame]:.5f} \n", fontsize=9)

# blit=False: `animate` clears and redraws the whole axes instead of returning
# the changed artists, which is what blitting requires.
anim = animation.FuncAnimation(fig, animate, frames=len(rects), interval=50, blit=False, repeat=False)
plt.close(fig)  # avoid a duplicate static figure below the animation
HTML(anim.to_jshtml())
Out[4]: